#pragma once
#include <flashinfer/attention/default_decode_params.cuh>
#include <flashinfer/attention/variants.cuh>

// Scalar types substituted by the code generator when this template is rendered.
// NOTE: `using namespace flashinfer;` is intentionally NOT used here. Per the
// original note, pulling that namespace in makes "at::Layout" ambiguous, so the
// flashinfer names used later in this header are spelled out fully qualified.
typedef {{ dtype_q }} DTypeQ;
typedef {{ dtype_kv }} DTypeKV;
typedef {{ dtype_o }} DTypeO;
typedef {{ dtype_idx }} IdType;

// Boolean feature switches rendered by the code generator; both are forwarded
// as template arguments to flashinfer::DefaultAttention further down.
static constexpr bool USE_SLIDING_WINDOW = {{ use_sliding_window }};
static constexpr bool USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};

// Compile-time sizes rendered by the code generator. They are not referenced
// in this header; presumably the including kernel source consumes them.
// NOTE(review): "ckv" / "kpe" look like the MLA compressed-KV and positional
// key parts of the head dimension -- confirm against the generator.
static constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }};
static constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }};
static constexpr int QO_TILE_LEN = {{ qo_tile_len }};

// Parameter struct for the MLA batch-decode path, instantiated with the
// generated query / KV / output / index types declared above.
using Params = flashinfer::BatchDecodeParamsMLA<DTypeQ, DTypeKV, DTypeO, IdType>;

// Attention variant used by the decode kernel: custom masking and ALiBi are
// hard-wired off, while sliding window and logits soft-cap follow the
// generated flags. Literal arguments carry `/*name=*/` comments in one
// consistent form (the original mixed `/*x=*/` and `/*x*/`).
using AttentionVariant = flashinfer::DefaultAttention<
    /*use_custom_mask=*/false,
    USE_SLIDING_WINDOW,
    USE_LOGITS_SOFT_CAP,
    /*use_alibi=*/false>;
